/*
    Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/** \file

    Declaration of the xfer, queue_item and propagation_history
    classes. The implementations are in xfer_impl.h */

#ifndef __xfer__
#define __xfer__

#include "config.h"
#include "mcrx-types.h"
#include "ray.h"
#include "optical.h"
#include <vector>
#include "mcrx-debug.h"
#include "boost/function.hpp"
#include "emission.h"
#include <tbb/concurrent_queue.h>
#include <tbb/atomic.h>
#include "mpi_master.h"

#ifdef HAVE_BOOST_SERIALIZATION
#include <boost/serialization/serialization.hpp>
#endif

namespace mcrx {
  // Forward declarations of the class templates defined further down
  // in this file.
  template<typename, typename> class queue_item;

  // Declared up front so that the overload for queue_items can be
  // defined out-of-line after the queue_item class itself.
  template <typename T_ray, typename T_cell_tracker> 
  queue_item<T_ray, T_cell_tracker> 
  independent_copy(const queue_item<T_ray, T_cell_tracker>& rhs);

  template<typename> class propagation_history;
  template <typename, typename> class xfer;
  // and some forward declarations...
  template <typename> class emergence;

  /** Names of the queue actions, for debugging output. There is one
      entry per queue_item::action_type enumerator, in declaration
      order (scatter, propagate, emerge, integrate, emission,
      invalid), so the two lists must be kept in sync when either one
      is changed. */
  const char* const action_strings [] = 
    {"scatter", "propagate", "emerge", "integrate", "emission", "invalid"};
}


/** The objects used to put rays on the queue: the ray itself plus the
    information necessary to continue tracing it (the pending action,
    an optional parameter, the cell it is in and its intensity
    normalization). */
template <typename T_ray, typename T_cell_tracker> 
class mcrx::queue_item { 
public: 
  /** The action_type enum indicates what action the ray should be
      continued with. scatter means the ray is to be scattered
      immediately, propagate that it should have an interaction point
      drawn and be propagated, emerge means it's a camera ray that
      should be propagated through the box and have its column density
      updated. invalid marks an item that carries no ray.
      NOTE(review): integrate and emission are presumably the two
      kinds of integration rays handled by xfer -- confirm.
      The enumerator order must match mcrx::action_strings. */
  enum action_type { scatter, propagate, emerge, integrate, emission, invalid };

  /// The action to take when the ray is popped.
  action_type action_;
  /** Extra parameter. For camera actions, the camera the ray is
      headed for. For scatter actions, the species that should scatter
      the ray. -1 when no parameter is given. */
  int param_;
  T_cell_tracker cell_;  ///< The cell the ray is located in.
  T_ray ray_; ///< The ray data.
  typename T_ray::T_lambda norm_; ///< The ray normalization.

  /** Creates an invalid item. param_ gets the same "no parameter"
      value (-1) that the other constructor defaults to, and cell_ is
      value-initialized, so that neither is read while indeterminate:
      save() streams both for every item, including invalid ones. */
  queue_item() : action_(invalid), param_(-1), cell_() {}

  /** The constructor makes reference copies of the norm array, and
      also the ray. Care should be taken when pushing a
      queue_item to make independent copies of these objects if
      necessary (see independent_copy).
      \param a    Action to resume the ray with.
      \param r    The ray.
      \param c    The cell the ray is located in.
      \param norm The ray normalization.
      \param par  Extra action parameter, -1 if unused. */
  queue_item (action_type a, const T_ray& r, T_cell_tracker c,
              const typename T_ray::T_lambda& norm, int par=-1) :
    action_(a), param_(par), cell_(c), ray_(r), norm_(norm) {}

  /** Assignment operator does a "shallow" assignment of the array
      components: the ray is assigned with shallow_assign() and the
      norm array with reference_copy(). */
  queue_item& operator=(const queue_item& rhs) {
    action_ = rhs.action_;
    param_ = rhs.param_;
    cell_ = rhs.cell_;
    ray_.shallow_assign(rhs.ray_);
    reference_copy(norm_, rhs.norm_);
    return *this;
  }

  /** Orders queue_items by the numerical value of their action only.
      A std::priority_queue pops its largest element first, so with
      this ordering the items whose action enumerator comes last in
      action_type are the ones that come out first.
      NOTE(review): if emerge rays are meant to have priority, the
      current enumerator order does not give that -- verify before
      using this with a priority queue. */
  bool operator<(const queue_item& rhs) const {
    return action_ < rhs.action_; }

#ifdef HAVE_BOOST_SERIALIZATION
  friend class boost::serialization::access;
  /** For efficiency, we only serialize the norm_ data, not the full
      array. We also assume it is the same length as the ray
      intensity. Note that we can't serialize the enum as an enum
      because the default enum serializer uses a temporary and thus
      won't work with skeleton/content; instead the storage of
      action_ is streamed in place as an int, which requires the enum
      to have the size of an int (asserted). */
  template<class T_arch>
  void save(T_arch& ar, const unsigned int version) const {
    assert(sizeof(action_)==sizeof(int));
    ar << *reinterpret_cast<const int*>(&action_)
       << param_ << cell_ << ray_;
    
    assert(same_size(norm_, ray_.intensity()));
    serialize_data_only(norm_, ar);
  }
  /** Counterpart of save(): reads the fields back in the same order,
      then sizes norm_ like the ray intensity before reading its
      data. */
  template<class T_arch>
  void load(T_arch& ar, const unsigned int version) {
    // Same precondition as in save(): action_ is read in place as an int.
    assert(sizeof(action_)==sizeof(int));
    ar >> *reinterpret_cast<int*>(&action_)
       >> param_ >> cell_ >> ray_;
    
    resize_like(norm_, ray_.intensity());
    unserialize_data_only(norm_, ar);
  }
  BOOST_SERIALIZATION_SPLIT_MEMBER()
#endif

  /** Returns an invalid item. For those, the only thing that matters
      is the action_type. */
  static queue_item make_invalid() {
    return queue_item(); }
};


/** Overload of the independent_copy function for queue_items.
    Returns a new queue item with the same action, cell and parameter
    as \a rhs, but whose ray and norm array are the results of
    independent_copy() on the corresponding members of \a rhs, instead
    of the reference copies the queue_item constructor would otherwise
    make. */
template <typename T_ray, typename T_cell_tracker> 
mcrx::queue_item<T_ray, T_cell_tracker> 
mcrx::independent_copy(const queue_item<T_ray, T_cell_tracker>& rhs)
{
  typedef queue_item<T_ray, T_cell_tracker> T_item;

  // Detach the two members that the queue_item constructor would
  // otherwise only reference-copy.
  const T_ray& detached_ray = independent_copy(rhs.ray_);
  const typename T_ray::T_lambda& detached_norm = independent_copy(rhs.norm_);

  return T_item(rhs.action_, detached_ray, rhs.cell_, detached_norm,
                rhs.param_);
}


/** One entry in the record of a ray's propagation through cells, used
    with forced scattering.  Each entry holds the traversed column
    density, the accumulated path length and the cell tracker of the
    cell traversed. Entries are put into a vector as the ray is
    propagated through the cells; because they are ordered by path
    length, the point where the forced scattering occurs can be found
    quickly with a binary search. */
template <typename T_cell_tracker>
class mcrx::propagation_history {
public:
  /** Creates an entry.
      \param n Traversed column density.
      \param l Accumulated path length.
      \param c The grid cell traversed. */
  propagation_history (const T_densities& n, T_float l, const T_cell_tracker& c)
    : n_(n), length_(l), cell_(c) {}

  /// Returns the traversed column density.
  const T_densities& n () const { return n_; }
  /// Returns the accumulated path length.
  T_float length () const { return length_; }
  /// Returns the grid cell traversed.
  const T_cell_tracker& cell () const { return cell_; }

  /// Orders entries by path length, so that binary search can be used.
  bool operator< (const propagation_history& rhs) const {
    return length_ < rhs.length_;
  }

private:
  T_densities n_; ///< Traversed column density.
  T_float length_; ///< Accumulated path length.
  T_cell_tracker cell_; ///< The grid cell traversed.
};

/** The xfer class is responsible for performing the ray tracing.  It
    relies on the emission object to create rays, uses the grid object
    to step the ray grid cell by grid cell, and uses the emergence
    object to register the rays leaving the volume. This class can run
    in both monochromatic and polychromatic mode, depending on the
    dust_model_type set. The details of the algorithm used to trace
    the rays are described in the Sunrise papers.

    Most member functions are only declared here; their definitions
    are in xfer_impl.h. The order of the data members below determines
    their initialization order, so it should not be changed without
    checking the constructors. */
template <typename dust_model_type, typename grid_type> 
class mcrx::xfer {
public:
  typedef dust_model_type T_dust_model;
  typedef grid_type T_grid;
  typedef typename T_dust_model::T_rng_policy T_rng_policy;
  typedef mcrx::T_float T_float;
  typedef typename T_dust_model::T_scatterer T_scatterer;
  typedef typename T_dust_model::T_lambda T_lambda;
  typedef typename T_dust_model::T_biaser T_biaser;
  typedef ray<T_lambda> T_ray;
  typedef emission<typename T_dust_model::T_chromatic_policy::T_base_policy, 
			    T_rng_policy> T_emission;
  typedef emergence<T_lambda> T_emergence;
  typedef typename T_grid::T_cell T_cell;
  typedef typename T_grid::T_cell_tracker T_cell_tracker;
  typedef propagation_history<T_cell_tracker>  T_propagation_history;

  // Note that queued rays store the cell tracker's T_code, not the
  // T_cell_tracker itself.
  typedef queue_item<T_ray, typename T_cell_tracker::T_code> T_queue_item;
  //typedef tbb::concurrent_priority_queue<T_queue_item> T_ray_queue;
  typedef tbb::concurrent_bounded_queue<T_queue_item> T_ray_queue;
  typedef typename T_queue_item::action_type T_action;

  // An entry in the pixel queue: an int paired with a 2-component
  // position. NOTE(review): the int is presumably the camera number
  // -- confirm against the code that fills the queue.
  typedef std::pair<int,blitz::TinyVector<T_float, 2> > T_pixel_queue_item;
  typedef tbb::concurrent_queue<T_pixel_queue_item> T_pixel_queue;

  // Identifiers for the stages of the algorithm. NOTE(review):
  // presumably used to label sections for performance monitoring
  // -- confirm.
  enum hpm_stages {Worker, Mainloop, Forced, Propagate, Intensities, Unforced, 
		   Ray_send, Send_wait, Emit, Locate, Scatter, Emerge, 
		   Camera_ray };
protected:

  /// The Grid
  T_grid& g;

  /// Reference to the emergence object, collecting emerging rays. 
  T_emergence& emergence_;

  /** Pointer to either the global or the local emergence, whichever
      is used, so we don't have to worry about which one it is in the
      algorithm. */
  T_emergence* current_emergence_;

  /** The emission object emits rays.  It is held by const reference,
      so no copy of the object is made by this member.
      NOTE(review): whether one emission object can safely be shared
      between xfer instances running in different threads depends on
      the emission class being reentrant -- confirm. */
  const T_emission& emission_;

  /** The dust_model object holds the scatterers and is responsible
      for translating the densities in the grid to optical depth. */
  const T_dust_model& dust_model;

  /// \name Ray tracing variables
  ///@{
  T_cell_tracker current_cell; ///< The grid_cell the ray is currently in.
  T_float scattering_reftau_; ///< Optical depth where scattering will happen.
  const T_float ray_min_i_; ///< Smallest ray intensity we care about.
  const T_float ray_max_i_; ///< Largest ray intensity allowed.
  T_ray ray_; ///< The ray object; its scattering count gates add_to_camera().
  T_biaser b_; ///< The biaser object.
  ///@}

  /** \name Cache variables
      To avoid repeatedly allocating memory for temporary arrays in
      shoot and propagate, we make them class members. (And actually
      they are made to point to the same 2D array so they are also
      localized in memory.)
  */
  ///@{
  /// The intensity normalization of the current ray.
  T_lambda norm_;
  /// Incoming intensity in the grid cell when adding intensities in shoot().
  T_lambda i_in_;
  /// Outgoing intensity in the grid cell when adding intensities in shoot().
  T_lambda i_out_;
  /// Optical depth through the grid cell when adding intensities in shoot().
  T_lambda dtau_;
  /// The optical depth to grid edge when forced scattering.
  T_lambda tau_exit_;
  /// (1-exp(-tau_exit)) for forced scattering.
  T_lambda onemexptau_exit_;
  /// The optical depth to the point of scattering.
  T_lambda scattering_tau_;
  /// The original intensity of the ray, saved for restoring later.
  T_lambda original_intensity_;
  /// The intensity to the camera in emerge_from_point.
  T_lambda intensity_to_camera_;
  /// Column density calculated in propagate().
  T_densities dn_; 
  ///@}

  /** The propagation_history vector, keeping track of which cells
      have been traversed since last event. */
  std::vector<T_propagation_history > traversed;

  const int n_scat_min_; ///< Min number of scatterings to register
  const int  n_scat_max_; ///< Max number of scatterings to register
  /** True if intensity information should be registered in the
      grid. */
  bool add_intensities_;
  /** True if absorption events should be treated as "effective
      scatterings" (this must be supported by the data type in the
      grid), otherwise the albedo is applied to the ray intensity at
      scattering events. */
  bool immediate_reemission_;

  /** The local_pending_ queue holds rays that are waiting to be
      processed by this xfer instance. This includes rays that have
      been split locally as well as rays that are transferred onto this
      task from other MPI tasks. */
  mutable boost::shared_ptr<T_ray_queue> local_pending_;
  /** The MPI master object, set with set_master(). It is dereferenced
      without a null check by mpi_master(). */
  mpi_master<xfer>* mpi_master_;

  /// Random number generator policy is used for getting random numbers. 
  T_rng_policy rng_;

  /// The thread number.
  int thread_num_;
  
  /** The pixel queue holds pixel positions of integration rays
      waiting to be run. */
  mutable boost::shared_ptr<T_pixel_queue> px_queue_;
  /** The number of integration rays per cell. Determines the ray
      subdivision when integrating. */
  int integrations_per_cell_;
  
  // The steps of the ray tracing algorithm. These are only declared
  // here; see xfer_impl.h for what the unnamed parameters mean.
  void emit(bool, long=0);
  void ray_mainloop();
  void process_ray_forced();
  void process_ray_unforced();
  // Propagates the ray to the boundary of the grid cell. 
  template <bool do_scatter, bool keep_traversed, bool add_intensity, 
	    bool integrate_ray>
  bool propagate(T_float max_length= blitz::huge(T_float()));

  void scatter(int resume_with = -1);   // Scatters ray.
  bool russian_roulette(T_float=1);
  // Calculates emerging intensity from current ray position.
  void emerge_from_point (const angular_distribution<T_lambda>&, 
			  const vec3d&, bool);
  void camera_ray (int, T_float=-1);
  void integrate_ray(int, bool, T_float = -1);
  bool subdivide_integration_ray(int, T_float, T_float, int, bool);
  vec3d px2dir(int cam, int px, int level) const;

  // Propagates an external ray to the grid boundary.
  void propagate_external (T_float max_length = blitz::huge(T_float())); 

  /** Adds the ray to the image, but only if the number of scatterings
      of ray_ is within the inclusive range [n_scat_min_,
      n_scat_max_]; otherwise nothing is added. The position,
      intensity and doppler arguments are passed on unchanged to the
      camera's add(). */
  template <typename T>
  void add_to_camera(typename T_emergence::T_camera& c, const vec3d& pos, 
		     const T& intensity, T_float doppler) {
    if ((ray_.scatterings()<=n_scat_max_)&&(ray_.scatterings()>=n_scat_min_))
      c.add(pos, intensity, doppler);
  };

  // Two overloads; the second takes the intensity as a blitz
  // expression argument.
  void add_intensity (const T_cell_tracker&, T_float);
  template <typename T>
  void add_intensity (const T_cell_tracker&, T_float, const blitz::ETBase<T>& intensity);

  // Management of the queue of pending rays.
  void check_ray_split(T_action, int=-1);
  void push_ray(T_action, const T_ray&, const T_cell_tracker&, 
		T_lambda&, int par=-1);
  T_action pop_ray(int&);

  /** This function consolidates all the cache arrays to point to a
      contiguous piece of memory. It's a no-op except for
      polychromatic runs which is called with an Array. */
  template<typename TT>
  void consolidate_arrays(TT) {};

  /** This function consolidates all the cache arrays to point to a
      contiguous piece of memory. This is the overload selected when
      the argument is a one-dimensional blitz::Array. */
  template<typename TT>
  void consolidate_arrays(blitz::Array<TT, 1>&);

public:
  xfer (/// The grid in which ray tracing is to be done
	T_grid&,
	/// The emission object.
	const T_emission&,
	/// The emergence object.
	T_emergence&,
	/// The dust_model object.
	const T_dust_model&,
	/// Seed for the random number generator.
	T_rng_policy& rng,
	/** The biaser to use. */
	const T_biaser& b,
	/// The lowest ray intensity before Russian Roulette starts.
	T_float i_min=1e-4,
	/// The maximum ray intensity allowed.
	T_float i_max=1e2,
	/** If true, radiative intensities are added to the grid. */
	bool add_intensities=true,
	/// If true, "effective scattering" is used instead of absorption.
	bool immediate_reemission=false,
	/// The minimum number of scatterings to register in the cameras.
	int minscat=0,
	/// The maximum number of scatterings to register in the cameras.
	int maxscat=blitz::huge(int()));
  xfer (const xfer& rhs);
  ~xfer ();
  
  /// Returns the random number generator policy object.
  T_rng_policy& rng () {return rng_;};

  /// Returns the grid (non-const, even though this function is const).
  T_grid& grid() const { return g; };
  /// Returns the emission object.
  const T_emission& emission() const { return emission_; };
  /// Returns the dust model object.
  const T_dust_model& model() const { return dust_model; };

  // Transfers one ray until it exits the volume or is absorbed
  void shoot();
  void shoot_isotropic(long=0);

  // The worker loop is passed a function with the signature of the
  // create_work_* members below.
  void worker_loop(tbb::atomic<long>&, long, 
		   boost::function<bool (xfer*, tbb::atomic<long>&, long)>);

  bool create_work_shooting(tbb::atomic<long>&, long);
  bool create_work_intensity_integration(tbb::atomic<long>&, long);
  bool create_work_emission_integration(tbb::atomic<long>&, long);

  void init_ray_integration();
  void integrate_ray(const blitz::TinyVector<int, 2>& px, 
		     int c, bool emission_integration);

  T_densities integrate_column_density(const vec3d& pos, const vec3d& dir);

  /** Sets the number of integrations per cell, which determines how
      much the integration rays subdivide. */
  void set_integrations_per_cell(int i) {integrations_per_cell_ = i; };

  /// Returns the (shared) pointer to the queue of pending rays.
  boost::shared_ptr<T_ray_queue> local_queue() const {return local_pending_; };
  /// Returns the (shared) pointer to the pixel queue.
  boost::shared_ptr<T_pixel_queue> pixel_queue() const {return px_queue_; };
  /// Sets the thread number returned by thread().
  void set_thread_number(int t) { thread_num_ = t; };
  /// Sets the MPI master object returned by mpi_master().
  void set_master(mpi_master<xfer>* m) { mpi_master_=m; };
  /** Returns the MPI master object. The stored pointer is
      dereferenced without a check, so it must be valid when this is
      called. */
  mpi_master<xfer>& mpi_master() { return *mpi_master_; };
  void init_queue_item(T_queue_item&);

  /// Returns the local MPI task number.
  int task() const { return mpi_rank(); };
  /// Returns the local thread number.
  int thread() const { return thread_num_; };
};

#endif


